Securing ML APIs with FastAPI

api
data science
machine learning
programming
security
fastapi
python
Author

Adejumo Ridwan Suleiman

The end goal of a machine learning model is to serve end users. Because models require regular updates to improve their accuracy and are often reused in other applications, they are typically exposed as an API. An ML API is an application that acts as a gateway between client requests and your machine learning model.

Let's say you have a recommender model on an e-library platform that suggests books based on user preferences. Exposed as an API, the model receives a user's preferences and returns book recommendations. The API also makes it easy to reuse the recommender model on another platform.

Because the data used to train machine learning models is often sensitive, API security is important to avoid data breaches and to prevent malicious clients from accessing the model. In this article, I will show you how to secure your machine learning APIs using FastAPI, an open-source Python framework for building secure and scalable APIs. As a Python library, it has a low learning curve for data scientists and machine learning engineers with Python backgrounds. If you are new to FastAPI, check out this course on ML deployment with FastAPI.

ML Model API Workflow. Image by Author

Fundamentals of API Security

APIs are common targets for data breaches and unauthorized access because of the data they expose, which is why API security matters. API security is the set of practices that protect an API from unauthorized access and abuse. Here are some of the most common API security threats:

  • Injection attacks (SQL, command): An attacker injects malicious code into the API, using SQL statements or terminal commands to read or modify data. These attacks usually target the application's database.
  • Cross-site scripting (XSS): An attacker exploits a vulnerable site to deliver malicious JavaScript to users; once it executes, the attacker can masquerade as the user and manipulate the user's data.
  • Cross-site request forgery (CSRF): An attacker tricks users into performing actions they did not intend to perform.
  • Man-in-the-middle (MITM) attacks: An attacker eavesdrops on the interaction between clients and the API to steal credentials such as login details and credit card information.

In this article, you will learn how to solve these issues and make your machine-learning API secure.

Prerequisites

To follow along, you should be comfortable with Python and have the fastapi, uvicorn, pydantic, scikit-learn, and joblib packages installed in a virtual environment. Later sections also use pyjwt, cryptography, and slowapi.

Setting Up FastAPI for ML APIs

  1. Create a project folder and a virtual environment.

  2. Copy and paste the following code into a new file called utilis.py in your project directory, then run it once. This trains a classification model on the iris dataset and saves it as model.pkl.

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    import joblib
    
    # Load the iris dataset
    iris = load_iris()
    X, y = iris.data, iris.target
    
    # Train a random forest classifier
    model = RandomForestClassifier()
    model.fit(X, y)
    
    # Save the trained model
    joblib.dump(model, 'model.pkl')
  3. Create an API endpoint for the machine learning model in a file called main.py (a quick test request is sketched after the code).

    from fastapi import FastAPI
    from pydantic import BaseModel
    import joblib
    
    # Load the trained model
    model = joblib.load("model.pkl")
    
    # Define the request body using Pydantic
    class PredictionRequest(BaseModel):
        sepal_length: float
        sepal_width: float
        petal_length: float
        petal_width: float
    
    app = FastAPI()
    
    @app.post("/predict")
    def predict(request: PredictionRequest):
        # Convert request data to a format suitable for the model
        data = [
            [
                request.sepal_length,
                request.sepal_width,
                request.petal_length,
                request.petal_width,
            ]
        ]
        # Make a prediction
        prediction = model.predict(data)
        # Return the prediction as a response
        return {"prediction": int(prediction[0])}
    
    # To run the app, use the command: uvicorn script_name:app --reload
    # where `script_name` is the name of your Python file (without the .py extension)
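
With the application running (for example via uvicorn main:app --reload), you can send a quick test request to confirm the endpoint works. The snippet below is a minimal sketch that assumes the API is served at the default local address http://127.0.0.1:8000 and uses the requests library; the measurement values are purely illustrative.

    import requests

    # Example iris measurements (illustrative values)
    payload = {
        "sepal_length": 5.1,
        "sepal_width": 3.5,
        "petal_length": 1.4,
        "petal_width": 0.2,
    }

    response = requests.post("http://127.0.0.1:8000/predict", json=payload)
    print(response.status_code)  # 200 on success
    print(response.json())       # e.g. {"prediction": 0}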

We now have our ML model API. Let's see how we can implement security best practices using this API.

Implementing Authentication and Authorization

Think of API authentication as a passkey that allows a client to access your API, ensuring that only authorized users can use it. There are various ways to implement API authentication in FastAPI, which you will learn about in the following sections.

Authentication vs Authorization. Source: Medium

API authentication alone is not enough to protect your API; you also need to implement API authorization. Authentication is like giving someone a key to your house, while authorization is like granting them access to specific rooms in the house.

API Key-Based Authentication

This is the most basic and most widely used way to implement API security.

  1. To implement key-based authentication in FastAPI, add the following code before the @app.post("/predict") endpoint in the main.py file, and update the import at the top of the file to from fastapi import FastAPI, Depends, Header, HTTPException.

    # Define the API key
    API_KEY = "your_api_key_here"
    
    # Dependency to verify the API key
    def get_api_key(api_key: str = Header(...)):
        if api_key != API_KEY:
            raise HTTPException(status_code=403, detail="Could not validate credentials")
    • API_KEY holds the key that clients must present. In practice it should be stored as an environment variable (for example in a .env file) rather than hard-coded.
    • The get_api_key() function reads the API key from the request header and checks whether it matches the expected value. If it matches, the user is granted access to the API; otherwise an HTTP 403 error is raised, telling the user that the provided credentials are invalid.
  2. Next, go to the predict() function and add api_key as an argument so that the API key is read from the client's request.

    @app.post("/predict")
    def predict(request: PredictionRequest,
                # added argument to require a valid API key from the client
                api_key: str = Depends(get_api_key)):
    • The Depends function prevents access to the /predict endpoint unless a valid API key is supplied. A client-side request example is sketched after the figure below.

    How authentication works in FastAPI
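
To call the protected endpoint, a client now has to send the key in a request header. The sketch below assumes the key is read from an API_KEY environment variable (rather than hard-coded) and uses the requests library; note that FastAPI maps the api_key parameter to an api-key header by default.

    import os
    import requests

    # In practice, load the key from the environment (e.g. populated from a .env file)
    api_key = os.getenv("API_KEY", "your_api_key_here")

    payload = {
        "sepal_length": 5.1,
        "sepal_width": 3.5,
        "petal_length": 1.4,
        "petal_width": 0.2,
    }

    # FastAPI converts the api_key parameter name to the api-key header
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json=payload,
        headers={"api-key": api_key},
    )
    print(response.status_code)  # 403 if the key is missing or invalid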

OAuth2 with JWT Tokens

Unlike API keys, OAuth2 is an authorization protocol that grants client applications access to resources hosted by other web applications on behalf of a user. With OAuth2, users do not need to give out their passwords to grant access to a resource.

A practical example is a client accessing your machine learning API with their Google ID without giving away their credentials; your API, in turn, sends a token back to the client that serves as a temporary pass for accessing the API. This is more secure than an API key: where an API key typically grants access to every resource in an API, OAuth2 grants the client access only to the specified resources.

When a user wants to access a machine learning API through a client application, the process typically uses OAuth2 for secure authentication. The client application starts by redirecting the user to an authentication server, where the user grants permission for the application to access their resources. The authentication server then issues an access token, often in the form of a JWT (JSON Web Token), to the client application. The application uses this token to make requests to the machine learning API. The API verifies the token to ensure the client is authorized to access the requested resources, providing secure and controlled access while protecting user data and privacy.

OAuth2 Workflow. Source: GeeksforGeeks

Let’s implement a simple OAuth2 with JWT on our machine learning API, by updating the main.py file as follows.

  1. Ensure you have installed pyjwt and add the following imports.

    from fastapi import FastAPI, Depends, HTTPException
    from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
    from pydantic import BaseModel
    import joblib
    from typing import Optional
    import jwt
    • OAuth2PasswordBearer and OAuth2PasswordRequestForm are used to implement OAuth2 in FastAPI.
    • jwt is used to create a JSON Web Token.
  2. Define a user model that accepts a username and password using the BaseModel class.

    # Define a user model
    class User(BaseModel):
        username: str
        password: str
  3. Create a function to authenticate users.

    def authenticate_user(username: str, password: str) -> Optional[User]:
        if username == "admin" and password == "password":
            return User(username=username, password=password)
        return None
    • The authenticate_user() function takes a client's username and password, checks them against the stored credentials (here hard-coded in place of a database lookup), and returns a User model if they match.
  4. Create a SECRET_KEY variable used to sign the JWT, and create an oauth2_scheme.

    SECRET_KEY = "your-secret-key"
    
    # OAuth2 scheme using password flow
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  5. Create a function that generates the access token using JWT.

    def create_access_token(data: dict):
        return jwt.encode(data, SECRET_KEY)
    • The create_access_token() function takes the user details and encodes them into a JWT signed with the SECRET_KEY (pyjwt signs with the HS256 algorithm by default).
  6. Create an authentication route to generate the access token.

    @app.post("/token")
    async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
        user = authenticate_user(form_data.username, form_data.password)
        if not user:
            raise HTTPException(
                status_code=401,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token = create_access_token({"sub": user.username})
        return {"access_token": access_token, "token_type": "bearer"}
    • The login_for_access_token() function receives the username and password through the OAuth2 password flow form and returns an access token for the client application.
    • If the user details are correct, an access token is created and returned; otherwise a 401 error is returned.
  7. Protect the API route by requiring JWT authentication.

    @app.post("/predict")
    async def predict(request: PredictionRequest, token: str = Depends(oauth2_scheme)):
        try:
            # Decode JWT token
            payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
            # Convert request data to a format suitable for the model
            data = [
                [
                    request.sepal_length,
                    request.sepal_width,
                    request.petal_length,
                    request.petal_width,
                ]
            ]
            # Make a prediction
            prediction = model.predict(data)
            # Return the prediction as a response
            return {"prediction": int(prediction[0])}
        except jwt.exceptions.DecodeError:
            raise HTTPException(
                status_code=401,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
    • The argument token: str = Depends(oauth2_scheme) means the API endpoint is protected with OAuth2 and receives the access token from the client application.
    • The token is decoded and its signature is verified with the SECRET_KEY. If verification succeeds, access is given to the model prediction; otherwise an error states that the provided credentials are invalid.

Key takeaways

  • The user logs in, and their details are encoded into an access token signed with a secret key.

  • The secured API endpoint verifies the access token's signature with the secret key before providing access to the resource. A client-side walkthrough of this flow is sketched below.

    How OAuth2 works in FastAPI.

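To see the whole flow from the client's side, you first exchange the username and password for a token at /token and then pass that token as a Bearer token when calling /predict. The sketch below is illustrative: it assumes the dummy admin/password credentials used above, a server running locally on port 8000, and the requests library.

    import requests

    BASE_URL = "http://127.0.0.1:8000"

    # Step 1: exchange the username and password for an access token.
    # OAuth2PasswordRequestForm expects form-encoded fields, not JSON.
    token_response = requests.post(
        f"{BASE_URL}/token",
        data={"username": "admin", "password": "password"},
    )
    access_token = token_response.json()["access_token"]

    # Step 2: call the protected endpoint with the Bearer token.
    payload = {
        "sepal_length": 5.1,
        "sepal_width": 3.5,
        "petal_length": 1.4,
        "petal_width": 0.2,
    }
    prediction_response = requests.post(
        f"{BASE_URL}/predict",
        json=payload,
        headers={"Authorization": f"Bearer {access_token}"},
    )
    print(prediction_response.json())  # e.g. {"prediction": 0}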

Role-Based Access Control (RBAC)

RBAC is an approach in which users are assigned roles that grant access to specific API resources. It is an efficient way of ensuring API security: instead of granting every user the same privileges, users are granted privileges based on what they need from the API.

  1. Let's add RBAC to the OAuth2 setup we just created by defining dummy user data inside main.py.

    # Dummy user data
    users_db = {
        "admin": {"username": "admin", "password": "password", "role": "admin"},
        "user": {"username": "user", "password": "password", "role": "user"}
    }
  2. To demonstrate RBAC, admin will have access to the model prediction endpoint while user will not. Update the User model to include a role field.

    # Define a user model with role
    class User(BaseModel):
        username: str
        password: str
        role: str
  3. Just before the API endpoint, add the following function.

    # Role-based access control dependency
    def role_checker(required_role: str):
        def role_dependency(current_user: User = Depends(get_current_user)):
            if current_user.role != required_role:
                raise HTTPException(
                    status_code=403,
                    detail="Operation not permitted",
                )
            return current_user
        return role_dependency
    • The role_checker() function takes the required role (here, admin) as an argument and returns a dependency that enforces it.
    • The inner role_dependency() function checks whether the current user has the required role, receiving the User through the get_current_user dependency (a possible implementation of get_current_user is sketched after this list).
    • If the user has the required role, access is granted; otherwise a 403 status code is returned with the detail "Operation not permitted".
  4. Update the API endpoint by adding a current_user argument.

    @app.post("/predict")
    async def predict(request: PredictionRequest, 
                      token: str = Depends(oauth2_scheme), 
                      current_user: User = Depends(role_checker("admin"))):
    • The current_user argument ensures that only users with the required admin role can access the endpoint.

      How RBAC works in FastAPI.
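
The role_dependency() function above relies on a get_current_user() dependency that resolves the current user from the JWT, which is not defined in the snippets so far. One possible implementation, assuming the users_db dictionary and the HS256-signed tokens from the previous sections, is sketched below.

    def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
        try:
            # Verify the token signature and read the username from the "sub" claim
            payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
            username = payload.get("sub")
        except jwt.exceptions.DecodeError:
            raise HTTPException(
                status_code=401,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
        user_data = users_db.get(username)
        if user_data is None:
            raise HTTPException(status_code=401, detail="User not found")
        return User(**user_data)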

Input Validation and Sanitization

Input validation means checking all inputs to an API to ensure they meet certain requirements, while sanitization means modifying inputs so that they become valid. Validation covers checks such as allowed characters, length, format, and range; sanitization changes the input itself, for example by shortening it or removing HTML tags.

Input validation and sanitization help prevent common attacks like SQL injection and cross-site scripting. You typically use validation when the user must provide a specific kind of input, for example a mobile number consisting only of digits, and sanitization when the user may provide varied input, such as the free-text fields of a user profile.

Using Pydantic for Input Validation

pydantic is a Python library that allows you to define and validate user inputs. It makes it easy to perform schema validation and serialization using type annotations. Earlier on, we used Pydantic to validate our User and PredictionRequest models; stricter constraints are shown after the snippet below.

class PredictionRequest(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

class User(BaseModel):
    username: str
    password: str
    role: str
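
Type annotations alone only guarantee that the inputs can be parsed as the declared types. For stricter validation and light sanitization, you can add field constraints and validators. The sketch below assumes Pydantic v2; the ranges and patterns are illustrative, not prescriptive.

    from pydantic import BaseModel, Field, field_validator

    class PredictionRequest(BaseModel):
        # Reject measurements outside a plausible range (values in centimetres)
        sepal_length: float = Field(gt=0, lt=10)
        sepal_width: float = Field(gt=0, lt=10)
        petal_length: float = Field(gt=0, lt=10)
        petal_width: float = Field(gt=0, lt=10)

    class User(BaseModel):
        # Restrict usernames to a safe character set and a sensible length
        username: str = Field(min_length=3, max_length=30, pattern=r"^[A-Za-z0-9_]+$")
        password: str = Field(min_length=8)
        role: str

        @field_validator("username", mode="before")
        @classmethod
        def strip_whitespace(cls, value):
            # Light sanitization: trim surrounding whitespace before other checks run
            return value.strip() if isinstance(value, str) else value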

Securing Data Transmission

When exchanging data between systems, it's important to use secure transmission protocols to protect the data from unauthorized access. Data transmission security ensures that only authorized parties can read the data in transit and protects the system from vulnerabilities. There are several protocols you can enforce to keep data transmission secure, such as HTTPS (Hypertext Transfer Protocol Secure), TLS (Transport Layer Security), SSH (Secure Shell), and FTPS (File Transfer Protocol Secure); here we will only cover HTTPS.

Enforcing HTTPS

HTTPS is the secure version of HTTP, in which data is encrypted as it is exchanged between a client and an API; this matters especially when confidential details such as login credentials or account details are shared. Unlike HTTP, which has no security layer and leaves data vulnerable, HTTPS adds an SSL/TLS layer to ensure that data is encrypted in transit.

HTTPS workflow. Source: GeeksForGeeks

  1. To secure data in transit for the API endpoint we created earlier, let's generate a self-signed certificate for testing. Copy and paste the following command into your terminal.

    openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes

    This will generate a self-signed SSL/TLS certificate with a private key using OpenSSL.

    • openssl: This is the command-line tool for using the various cryptography functions of OpenSSL's library.
    • req: This sub-command is used to create and process certificate signing requests (CSRs) and, in this case, to create a self-signed certificate.
    • -x509: This option is used to generate a self-signed certificate instead of a certificate request.
    • -newkey rsa:4096: This option does two things:
      • -newkey: It generates a new private key along with the certificate.
      • rsa:4096: This specifies the type of key to create, in this case an RSA key with a size of 4096 bits.
    • -keyout key.pem: This specifies the file where the newly generated private key will be saved (key.pem).
    • -out cert.pem: This specifies the file where the self-signed certificate will be saved (cert.pem).
    • -days 365: This sets the certificate to be valid for 365 days (1 year).
    • -nodes: This option ensures that the private key will not be encrypted with a passphrase. Without this option, OpenSSL would prompt for a passphrase to encrypt the private key.
  2. Provide the necessary information to create the key.pem (private key) and cert.pem (certificate).

    Generating a self-signed certificate using OpenSSL.


  3. At the end of the main.py file, add the following code.

    import uvicorn 
    
    if __name__ == "__main__":
        uvicorn.run(
            app, host="127.0.0.1", port=8000, ssl_keyfile="key.pem", ssl_certfile="cert.pem"
        )

    uvicorn.run ensures your application runs on HTTPS using the generated key.pem and cert.pem.

  4. You can now run the API with the following command in your terminal.

    python main.py

    URL to the Model API displayed on terminal

    In a production environment, it is recommended to use a reverse proxy server like Nginx to handle SSL termination and forward requests to the FastAPI application, to ensure better performance and security. For local testing against the self-signed certificate, see the sketch below.
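
When testing locally against the self-signed certificate, standard HTTP clients will refuse the connection unless you tell them which certificate to trust. The snippet below is a minimal sketch using the requests library and the cert.pem generated above; if hostname verification fails for 127.0.0.1, you can fall back to verify=False for local testing only, never in production.

    import requests

    # Request a token over HTTPS, trusting the self-signed certificate
    response = requests.post(
        "https://127.0.0.1:8000/token",
        data={"username": "admin", "password": "password"},
        verify="cert.pem",
    )
    print(response.json())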

Encrypting Sensitive Data

Encryption is the encoding of sensitive information so that, even if the data leaks, its content remains unreadable; the data is decoded only when it reaches its intended destination. This is very useful for protecting sensitive data such as passwords, since only authorized parties holding the decryption key can read it. Here is a simple example of how encryption works.

  1. Import all necessary libraries and create an instance of the FastAPI class.

    from fastapi import FastAPI, HTTPException, Depends
    from pydantic import BaseModel
    from cryptography.fernet import Fernet
    
    app = FastAPI()
  2. Next, generate a key for encryption and decryption using the Fernet class.

    # Generate a key for encryption and decryption
    key = Fernet.generate_key()
    cipher_suite = Fernet(key)
  3. Create an Item model for receiving the plaintext and an EncryptedItem model for the encrypted text.

    # Models
    class Item(BaseModel):
        plaintext: str
    
    class EncryptedItem(BaseModel):
        ciphertext: str
  4. Create the encryption endpoint.

    @app.post("/encrypt/", response_model=EncryptedItem)
    async def encrypt_item(item: Item):
        plaintext = item.plaintext.encode("utf-8")
        ciphertext = cipher_suite.encrypt(plaintext)
        return {"ciphertext": ciphertext.decode("utf-8")}

    This takes the given plaintext and encodes it to UTF-8 bytes; cipher_suite then encrypts the plaintext into ciphertext, a string of seemingly random characters.

  5. Create the decryption endpoint, which turns the ciphertext back into the original plaintext.

    # Decryption endpoint
    @app.post("/decrypt/", response_model=Item)
    async def decrypt_item(encrypted_item: EncryptedItem):
        ciphertext = encrypted_item.ciphertext.encode("utf-8")
        try:
            plaintext = cipher_suite.decrypt(ciphertext)
            return {"plaintext": plaintext.decode("utf-8")}
        except Exception as e:
            raise HTTPException(status_code=400, detail="Decryption failed")

    This endpoint takes the encrypted_item, encodes the ciphertext to UTF-8 bytes, and decrypts it back to plaintext using cipher_suite. If an invalid ciphertext is provided, a 400 status code is returned with the detail "Decryption failed". Note that the key is regenerated on every restart; a sketch for loading it from the environment follows below.

    Encrypting sensitive data in FastAPI
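
One practical caveat: Fernet.generate_key() produces a new key every time the application starts, so anything encrypted before a restart can no longer be decrypted afterwards. A common pattern is to load the key from an environment variable or a secret store instead. The sketch below is illustrative; FERNET_KEY is an assumed variable name, not something defined earlier in this article.

    import os
    from cryptography.fernet import Fernet

    # Reuse a configured key if one exists, otherwise generate a fresh one for this run
    key = os.getenv("FERNET_KEY")
    if key is None:
        key = Fernet.generate_key().decode("utf-8")

    cipher_suite = Fernet(key.encode("utf-8"))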

Rate Limiting and Throttling

Another way of securing an API is to limit the number of calls made to the server; this is where rate limiting and throttling come into play. Rate limiting controls the amount of incoming and outgoing traffic to or from a service to prevent abuse and overloading of the server, while throttling temporarily slows down the rate at which the API processes requests. Let's apply rate limiting and throttling to our previous example.

  1. Ensure you have installed the slowapi library, which adds rate limiting to FastAPI applications, and add the following new imports. Note that Request must also be imported from fastapi, because slowapi requires a request: Request parameter in decorated endpoints.

    from fastapi import Request  # required by endpoints decorated with @limiter.limit
    from slowapi import Limiter, _rate_limit_exceeded_handler
    from slowapi.util import get_remote_address
    from slowapi.errors import RateLimitExceeded
  2. Next, initialize the rate limiter.

    limiter = Limiter(key_func=get_remote_address)
    app = FastAPI()
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
  3. Apply the rate limiter to the /token endpoint using the @limiter.limit("5/minute") decorator, and add the request: Request parameter to the login_for_access_token function.

    @app.post("/token")
    @limiter.limit("5/minute")
    async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
  4. Also, apply a limit of 10 requests per minute to the /predict endpoint. Rename the predict function's parameter from request to prediction_request so it does not clash with the new request: Request parameter; a quick way to verify the limits is sketched after this step.

    @app.post("/predict")
    @limiter.limit("10/minute")
    async def predict(
        request: Request,
        prediction_request: PredictionRequest,
        token: str = Depends(oauth2_scheme),
        current_user: User = Depends(role_checker("admin")),
    ):
        # The function body stays the same as before, except that the input
        # fields are now read from prediction_request instead of request.
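
You can verify the limiter by calling the /token endpoint more than five times within a minute; once the limit is exceeded, slowapi responds with HTTP 429 Too Many Requests. The loop below is a quick sketch that assumes the server is running locally and uses the requests library.

    import requests

    for attempt in range(7):
        response = requests.post(
            "http://127.0.0.1:8000/token",
            data={"username": "admin", "password": "password"},
        )
        # The first five calls within a minute should return 200; later ones 429
        print(attempt + 1, response.status_code)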

Conclusion

You can combine all of these methods in your ML model API to make it as secure as possible. In this article, you learned how to implement various API security techniques in FastAPI, such as authentication, authorization, input validation, sanitization, encryption, rate limiting, and throttling. If you want to dive deeper into model deployment with FastAPI, here are some extra resources to keep you busy.

Need Help with Data? Let’s Make It Simple.

At LearnData.xyz, we’re here to help you solve tough data challenges and make sense of your numbers. Whether you need custom data science solutions or hands-on training to upskill your team, we’ve got your back.

📧 Shoot us an email at admin@learndata.xyz—let’s chat about how we can help you make smarter decisions with your data.

Your next breakthrough could be one email away. Let’s make it happen!